{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Sentiment Classification on Large Movie Reviews\n",
    "\n",
    "[Sentiment Analysis](https://en.wikipedia.org/wiki/Sentiment_analysis) is a classic natural language processing problem. In this example, a large movie review dataset from IMDB is used for a sentiment classification task with several deep learning approaches. The labeled dataset consists of 50,000 [IMDB](http://www.imdb.com/) movie reviews (good or bad): 25,000 highly polar reviews for training and 25,000 for testing. The dataset was originally collected by Stanford researchers and used in a [2011 paper](http://ai.stanford.edu/~amaas/papers/wvSent_acl2011.pdf), where the highest accuracy achieved without using the unlabeled data was 88.33%. This example illustrates several deep learning approaches to sentiment classification with the [BigDL](https://github.com/intel-analytics/BigDL) Python API.\n",
    "\n",
    "### Load the IMDB Dataset\n",
    "The IMDB dataset needs to be loaded into BigDL. Note that the dataset has been pre-processed: each review is encoded as a sequence of integers, where each integer is the word's rank by overall frequency in the dataset. For instance, `5` stands for the 5th most frequent word in the data. This makes it convenient to filter words by frequency, for example to keep only the 5,000 most common words and/or to eliminate the 30 most common words. Let's define functions to load the pre-processed data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing text dataset\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finished processing text\n"
     ]
    }
   ],
   "source": [
    "from bigdl.dataset import base\n",
    "import numpy as np\n",
    "\n",
    "def download_imdb(dest_dir):\n",
    "    \"\"\"Download pre-processed IMDB movie review data\n",
    "\n",
    "    :argument\n",
    "        dest_dir: destination directory to store the data\n",
    "\n",
    "    :return\n",
    "        The absolute path of the stored data\n",
    "    \"\"\"\n",
    "    file_name = \"imdb.npz\"\n",
    "    file_abs_path = base.maybe_download(file_name,\n",
    "                                        dest_dir,\n",
    "                                        'https://s3.amazonaws.com/text-datasets/imdb.npz')\n",
    "    return file_abs_path\n",
    "\n",
    "def load_imdb(dest_dir='/tmp/.bigdl/dataset'):\n",
    "    \"\"\"Load IMDB dataset.\n",
    "\n",
    "    :argument\n",
    "        dest_dir: where to cache the data (default `/tmp/.bigdl/dataset`).\n",
    "\n",
    "    :return\n",
    "        the train, test separated IMDB dataset.\n",
    "    \"\"\"\n",
    "    path = download_imdb(dest_dir)\n",
    "    with np.load(path, allow_pickle=True) as f:\n",
    "        x_train = f['x_train']\n",
    "        y_train = f['y_train']\n",
    "        x_test = f['x_test']\n",
    "        y_test = f['y_test']\n",
    "\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "print('Processing text dataset')\n",
    "(x_train, y_train), (x_test, y_test) = load_imdb()\n",
    "print('finished processing text')"
   ]
  },
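  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Because each index is a frequency rank, the filtering described above reduces to integer comparisons. Below is a minimal sketch in plain Python, assuming 1-based ranks as in the downloaded data; `filter_by_rank`, `num_words` and `skip_top` are hypothetical names used only for illustration, not part of BigDL. It drops the filtered words, whereas the pre-processing later in this notebook replaces rare words with an out-of-vocabulary index instead.\n",
    "\n",
    "```python\n",
    "def filter_by_rank(seq, num_words, skip_top=0):\n",
    "    # Keep rank r only if skip_top < r <= num_words: this drops the skip_top\n",
    "    # most common words and every word rarer than the num_words-th one.\n",
    "    return [r for r in seq if skip_top < r <= num_words]\n",
    "\n",
    "review = [1, 14, 22, 16, 43, 530, 973, 6000]  # made-up encoded review\n",
    "print(filter_by_rank(review, num_words=5000, skip_top=30))  # [43, 530, 973]\n",
    "```"
   ]
  },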
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "To set a proper maximum sequence length, we need to look at the data and check the length distribution of the reviews. The box and whisker plot below shows the distribution of review lengths in words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\nReview length: \nMean 233.76 words (172.911495)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAFpCAYAAACVjP/1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF8FJREFUeJzt3W9sHPWdx/HPN2vjpcF3GOEjCJOATtFpzUpHKwuQ8IOuTiLAE9InDVbUoNhqGgGrVI0U0/gBpFVCidRU1GqTBnlFKoWFSLRpFOACQitVhqMlnBD54ytENBaO0hBwRKmRE9v53gNPwpqS2LP+M1n/3i9ptbPfndn9rgT+ZOY38xtzdwEAwrMg6QYAAMkgAAAgUAQAAASKAACAQBEAABAoAgAAAkUAAECgCAAACBQBAACBIgAAIFA1STdwOddff73fcsstSbcBAFXlnXfe+cTdGydb74oOgFtuuUUHDx5Mug0AqCpm1j+V9TgEBACBIgAAIFAEAAAEigAAgEARAAAQKAIAAAJFAABAoAgAAAjUpAFgZjebWcnMjprZETNbF9WfMLMTZvZu9Li/bJsfm9kxM/uLmS0rq98b1Y6Z2WOz85MAAFMxlT2AUUnr3b1Z0l2SHjGz5ui9X7j77dHjZUmK3ntQ0m2S7pX0azNLmVlK0q8k3SepWVJb2ecAVaNYLCqbzSqVSimbzapYLCbdElCRSaeCcPeTkk5Gy5+bWZ+kmy6zyQOSnnf3s5L+ambHJN0RvXfM3T+UJDN7Plr36DT6B+ZUsVhUV1eXenp61Nraqt7eXnV0dEiS2traEu4OiCfWGICZ3SLpm5L+FJUeNbP3zKxgZg1R7SZJH5VtNhDVLlUHqsbmzZvV09OjXC6n2tpa5XI59fT0aPPmzUm3BsQ25QAws2skvSjph+7+d0nbJf27pNs1vofw85loyMzWmNlBMzt4+vTpmfhIYMb09fWptbV1Qq21tVV9fX0JdQRUbkoBYGa1Gv/jv9vdfydJ7n7K3cfc/bykZ/TlYZ4Tkm4u27wpql2qPoG773T3FndvaWycdDZTYE5lMhn19vZOqPX29iqTySTUEVC5qZwFZJJ6JPW5+7ay+o1lq31H0uFoeZ+kB82szsxulbRU0p8lvS1pqZndamZXaXygeN/M/AxgbnR1damjo0OlUkkjIyMqlUrq6OhQV1dX0q0BsU3lfgB3S/qepENm9m5U26jxs3hul+SSjkv6gSS5+xEz26Pxwd1RSY+4+5gkmdmjkg5ISkkquPuRGfwtwKy7MNCbz+fV19enTCajzZs3MwCMqmTunnQPl9TS0uLcEAYA4jGzd9y9ZbL1uBIYAAJFAABAoAgAAAgUAQAAgSIAACBQBAAABIoAAIBAEQAAECgCAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAABiKhaLymazSqVSymazKhaLSbcEVGQqN4QBECkWi+rq6lJPT49aW1vV29urjo4OSeKmMKg63BAGiCGbzaq7u1u5XO5irVQqKZ/P6/Dhw5fZEpg7U70hDAEAxJBKpTQ8PKza2tqLtZGREaXTaY2NjSXYGfAl7ggGzIJMJqPe3t4Jtd7eXmUymYQ6AipHAAAxdHV1qaOjQ6VSSSMjIyqVSuro6FBXV1fSrQGxMQgMxHBhoDefz6uvr0+ZTEabN29mABhViTEAAJhnGAMAAFwWAQAAgSIAACBQBAAABIoAAIBAEQAAECgCAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAACAQBEAABAoAgAAAkUAADEVi0Vls1mlUills1kVi8WkWwIqQgAAMRSLRa1bt05DQ0OSpKGhIa1bt44QQFUiAIAYNmzYoJqaGhUKBQ0PD6tQKKimpkYbNmxIujUgNgIAiGFgYEC7du1SLpdTbW2tcrmcdu3apYGBgaRbA2IjAAAgUAQAEENTU5NWrVqlUqmk
kZERlUolrVq1Sk1NTUm3BsRGAAAxbN26VWNjY2pvb1ddXZ3a29s1NjamrVu3Jt0aEBsBAMTQ1tamp59+WgsXLpSZaeHChXr66afV1taWdGtAbObuSfdwSS0tLX7w4MGk2wCAqmJm77h7y2TrsQcAAIEiAAAgUAQAAASKAACAQE0aAGZ2s5mVzOyomR0xs3VR/Toze83MPoieG6K6mdkvzeyYmb1nZt8q+6yHovU/MLOHZu9nAQAmM5U9gFFJ6929WdJdkh4xs2ZJj0l63d2XSno9ei1J90laGj3WSNoujQeGpMcl3SnpDkmPXwgNAMDcmzQA3P2ku/9vtPy5pD5JN0l6QNKuaLVdkpZHyw9I+q2Pe0vStWZ2o6Rlkl5z90F3PyPpNUn3zuivAQBMWawxADO7RdI3Jf1J0g3ufjJ662+SboiWb5L0UdlmA1HtUnUAQAKmHABmdo2kFyX90N3/Xv6ej19NNiNXlJnZGjM7aGYHT58+PRMfCQD4GlMKADOr1fgf/93u/ruofCo6tKPo+eOofkLSzWWbN0W1S9UncPed7t7i7i2NjY1xfgsAIIapnAVkknok9bn7trK39km6cCbPQ5L+UFZfFZ0NdJekz6JDRQck3WNmDdHg7z1RDQCQgJoprHO3pO9JOmRm70a1jZJ+JmmPmXVI6pf03ei9lyXdL+mYpC8krZYkdx80s59Kejta7yfuPjgjvwIAEBuTwQHAPMNkcACAyyIAACBQBAAABIoAAIBAEQAAECgCAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAEFM+n1c6nZaZKZ1OK5/PJ90SUBECAIghn89rx44d2rJli4aGhrRlyxbt2LGDEEBVYjI4IIZ0Oq0tW7boRz/60cXatm3btHHjRg0PDyfYGfClqU4GRwAAMZiZhoaG9I1vfONi7YsvvtDChQt1Jf+/hLAwGygwC+rq6rRjx44JtR07dqiuri6hjoDKTeWGMAAi3//+99XZ2SlJWrt2rXbs2KHOzk6tXbs24c6A+AgAIIbu7m5J0saNG7V+/XrV1dVp7dq1F+tANWEMAADmGcYAAACXRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAABiKhaLymazSqVSymazKhaLSbcEVIQLwYAYisWiurq61NPTo9bWVvX29qqjo0OS1NbWlnB3QDxcCAbEkM1m1d3drVwud7FWKpWUz+d1+PDhBDsDvsRsoMAsSKVSGh4eVm1t7cXayMiI0um0xsbGEuwM+BJXAgOzIJPJaNOmTRPGADZt2qRMJpN0a0BsBAAQQy6X01NPPaX29nZ9/vnnam9v11NPPTXhkBBQLQgAIIZSqaTOzk4VCgXV19erUCios7NTpVIp6daA2BgDAGJgDADVgDEAYBZkMhn19vZOqPX29jIGgKpEAAAxdHV1qaOjQ6VSSSMjIyqVSuro6FBXV1fSrQGxcSEYEMOFi73y+bz6+vqUyWS0efNmLgJDVWIMAADmGcYAAACXRQAAMTEZHOYLxgCAGJgMDvMJYwBADEwGh2rAZHDALOBCMFQDBoGBWcCFYJhPCAAgBi4Ew3zCIDAQAxeCYT5hDwAAAsUeABADp4FiPuEsICCGbDar5cuXa+/evRcPAV14zWmguFJM9Swg9gCAGI4ePaqhoSEVCoWLewDt7e3q7+9PujUgNgIAiOGqq67S3XffPWEQ+O6779bJkyeTbg2IjUFgIIazZ8/qhRdemHBP4BdeeEFnz55NujUgNgIAiKGurk4rVqyYcE/gFStWqK6uLunWgNgmDQAzK5jZx2Z2uKz2hJmdMLN3o8f9Ze/92MyOmdlfzGxZWf3eqHbMzB6b+Z8CzL5z587pjTfeUHd3t4aHh9Xd3a033nhD586dS7o1ILap7AE8K+ner6n/wt1vjx4vS5KZNUt6UNJt0Ta/NrOUmaUk/UrSfZKaJbVF6wJVpbm5WStXrlQ+n1c6nVY+n9fKlSvV3Mx/zqg+kwaAu/9R0uAU
P+8BSc+7+1l3/6ukY5LuiB7H3P1Ddz8n6floXaCqdHV16bnnnpuwB/Dcc88xFQSq0nTOAnrUzFZJOihpvbufkXSTpLfK1hmIapL00Vfqd07ju4FEMBUE5pNKA2C7pJ9K8uj555LaZ6IhM1sjaY0kLV68eCY+EphRbW1t/MHHvFDRWUDufsrdx9z9vKRnNH6IR5JOSLq5bNWmqHap+td99k53b3H3lsbGxkraAwBMQUUBYGY3lr38jqQLZwjtk/SgmdWZ2a2Slkr6s6S3JS01s1vN7CqNDxTvq7xtAMB0TXoIyMyKkr4t6XozG5D0uKRvm9ntGj8EdFzSDyTJ3Y+Y2R5JRyWNSnrE3ceiz3lU0gFJKUkFdz8y478GADBlTAYHAPMMt4QEAFwWAQAAgSIAACBQBAAQU7FYVDabVSqVUjabVbFYTLoloCIEABBDsVjUunXrNDQ0JEkaGhrSunXrCAFUJQIAiGHDhg0aGRmRJF04g25kZEQbNmxIsi2gIgQAEMPAwIDS6bQKhYLOnj2rQqGgdDqtgYGBpFsDYiMAgJhyudyE6aBzuVzSLQEVIQCAmPbs2TPhlpB79uxJuiWgIgQAEENNTY3S6bS6u7t1zTXXqLu7W+l0WjU105lZHUgGAQDEMDY2pquvvlqSZGaSpKuvvlpjY2NJtgVUhAAAYmhublZra6tOnjyp8+fP6+TJk2ptbeWWkKhKBAAQQy6X0/79+7VlyxYNDQ1py5Yt2r9/PwPBqEoEABBDqVRSZ2enCoWC6uvrVSgU1NnZqVKplHRrQGxMBw3EkEqlNDw8rNra2ou1kZERpdNpxgFwxWA6aGAWZDIZbdq0acJcQJs2bVImk0m6NSA2AgCIIZfL6cknn9Qnn3wid9cnn3yiJ598kjEAVCUCAIhh7969SqfTGhwclLtrcHBQ6XRae/fuTbo1IDYCAIhhYGBA9fX1OnDggM6dO6cDBw6ovr6euYBQlQgAIKb169crl8uptrZWuVxO69evT7oloCIEABDTtm3bVCqVNDIyolKppG3btiXdElARJjABYmhqatI//vEPtbe3q7+/X0uWLNHw8LCampqSbg2IjT0AIIatW7devAbgwlxAtbW12rp1a5JtARUhAIAY2tratGLFiglzAa1YsUJtbW1JtwbERgAAMRSLRb300kt65ZVXdO7cOb3yyit66aWXuCcwqhJTQQAxZLNZLV++XHv37lVfX58ymczF14cPH066PUDS1KeCYBAYiOHo0aM6deqUrrnmGrm7hoaG9Jvf/Eaffvpp0q0BsXEICIghlUppbGxswk3hx8bGlEqlkm4NiI0AAGIYHR1VXV3dhFpdXZ1GR0cT6gioHAEAxLR69Wrl83ml02nl83mtXr066ZaAijAGAMTQ1NSkZ599Vrt371Zra6t6e3u1cuVKLgRDVWIPAIhh69atGh0dVXt7u9LptNrb2zU6OsqFYKhKBAAQQ1tbmxYtWqTjx4/r/PnzOn78uBYtWsSFYKhKBAAQw7Jly3To0CE1NDRowYIFamho0KFDh7Rs2bKkWwNiIwCAGF599VXV19frxRdf1PDwsF588UXV19fr1VdfTbo1IDYCAIhp9+7dE+4HsHv37qRbAipCAAAx7d+//7KvgWpBAAAxLFy4UDt37tTDDz+szz77TA8//LB27typhQsXJt0aEBsBAMTwzDPPKJVKafv27br22mu1fft2pVIpPfPMM0m3BsRGAAAxvPnmm3J3LVq0SAsWLNCiRYvk7nrzzTeTbg2IjemggRjS6bSWLFmiDz74QO4uM9PSpUvV39+v4eHhpNsDJDEdNDArzp49q/fff18LFiy4GADvv/9+0m0BFeEQEFCB8+fPT3gGqhEBAFTgtttuU39/v2677bakWwEqxiEgIKaamhodOXJES5Ysufia+wGgGrEHAMT01T/2/PFHtSIAACBQBAAABIoAAIBAEQAAECgCAAACNWkAmFnBzD42s8NltevM7DUz+yB6
bojqZma/NLNjZvaemX2rbJuHovU/MLOHZufnAACmaip7AM9Kuvcrtcckve7uSyW9Hr2WpPskLY0eayRtl8YDQ9Ljku6UdIekxy+EBgAgGZMGgLv/UdLgV8oPSNoVLe+StLys/lsf95aka83sRknLJL3m7oPufkbSa/rnUAEAzKFKxwBucPeT0fLfJN0QLd8k6aOy9Qai2qXq/8TM1pjZQTM7ePr06QrbAwBMZtqDwD4+n/SMzSnt7jvdvcXdWxobG2fqYwEAX1FpAJyKDu0oev44qp+QdHPZek1R7VJ1AEBCKg2AfZIunMnzkKQ/lNVXRWcD3SXps+hQ0QFJ95hZQzT4e09UAwAkZNLZQM2sKOnbkq43swGNn83zM0l7zKxDUr+k70arvyzpfknHJH0habUkufugmf1U0tvRej9x968OLAMA5hC3hARiMLNLvncl/7+EsEz1lpBcCQwAgSIAACBQBAAABIoAAIBAEQAAECgCAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAACAQBEAABAoAgAAAkUAAECgCAAACBQBAACBIgAAIFAEAAAEigAAgEARAAAQKAIAAAJFAABAoAgAAAgUAQAAgSIAACBQBAAABIoAAIBAEQAAECgCAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAACAQBEAABAoAgAAAkUAAECgCAAACBQBAACBIgAAIFAEAAAEigAAgEARAAAQKAIAAAJFAABAoAgAAAgUAQAAgSIAACBQ0woAMztuZofM7F0zOxjVrjOz18zsg+i5Iaqbmf3SzI6Z2Xtm9q2Z+AEAgMrMxB5Azt1vd/eW6PVjkl5396WSXo9eS9J9kpZGjzWSts/AdwMAKjQbh4AekLQrWt4laXlZ/bc+7i1J15rZjbPw/UBsZjalx3Q/A7iSTDcAXNKrZvaOma2Jaje4+8lo+W+SboiWb5L0Udm2A1ENSJy7T+kx3c8AriQ109y+1d1PmNm/SXrNzP6v/E13dzOL9V99FCRrJGnx4sXTbA8AcCnT2gNw9xPR88eSfi/pDkmnLhzaiZ4/jlY/Ienmss2botpXP3Onu7e4e0tjY+N02gNm3KX+Fc+/7lGNKg4AM1toZvUXliXdI+mwpH2SHopWe0jSH6LlfZJWRWcD3SXps7JDRUDVKD+cw6EdVLPpHAK6QdLvo4GtGknPuft/m9nbkvaYWYekfknfjdZ/WdL9ko5J+kLS6ml8NwBgmioOAHf/UNJ/fk39U0n/9TV1l/RIpd8HAJhZXAkMAIEiAAAgUAQAAASKAACAQBEAABAoAgAAAkUAAECgCAAACBQBAACBIgAAIFAEAAAEigAAgEARAAAQKAIAAAJFAABAoAgAAAgUAQAAgSIAACBQBAAABIoAAIBAEQAAEKiapBsAZsN1112nM2fOzPr3mNmsfn5DQ4MGBwdn9TsQLgIA89KZM2fk7km3MW2zHTAIG4eAACBQBAAABIoAAIBAEQAAECgCAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAACAQDEXEOYlf/xfpCf+Nek2ps0f/5ekW8A8RgBgXrJNf583k8H5E0l3gfmKQ0AAECgCAAACRQAAQKAYA8C8NR9uptLQ0JB0C5jHCADMS3MxAGxm82KgGeHiEBAABIoAAIBAEQAAECgCAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAACAQM15AJjZvWb2FzM7ZmaPzfX3AwDGzWkAmFlK0q8k3SepWVKbmTXPZQ8AgHFzvQdwh6Rj7v6hu5+T9LykB+a4BwCA5n420JskfVT2ekDSnXPcA/BPKp06Ou52zB6KK8kVNx20ma2RtEaSFi9enHA3CAV/mBGiuT4EdELSzWWvm6LaRe6+091b3L2lsbFxTpsDgJDMdQC8LWmpmd1qZldJelDSvjnuAQCgOT4E5O6jZvaopAOSUpIK7n5kLnsAAIyb8zEAd39Z0stz/b0AgIm4EhgA
AkUAAECgCAAACBQBAACBIgAAIFAEAAAEigAAgEARAAAQKAIAAAJlV/IsiGZ2WlJ/0n0Al3C9pE+SbgL4GkvcfdLZNK/oAACuZGZ20N1bku4DqBSHgAAgUAQAAASKAAAqtzPpBoDpYAwAAALFHgAABIoAAGIys4KZfWxmh5PuBZgOAgCI71lJ9ybdBDBdBAAQk7v/UdJg0n0A00UAAECgCAAACBQBAACBIgAAIFAEABCTmRUl/Y+k/zCzATPrSLonoBJcCQwAgWIPAAACRQAAQKAIAAAIFAEAAIEiAAAgUAQAAASKAACAQBEAABCo/wej+ZYTcGUFjwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x432 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib\n",
    "matplotlib.use('Agg')\n",
    "%pylab inline \n",
    "# Summarize review length\n",
    "from matplotlib import pyplot\n",
    "\n",
    "print(\"Review length: \")\n",
    "X = np.concatenate((x_train, x_test), axis=0)\n",
    "result = [len(x) for x in X]\n",
    "print(\"Mean %.2f words (%f)\" % (np.mean(result), np.std(result)))\n",
    "# plot review length\n",
    "# Create a figure instance\n",
    "fig = pyplot.figure(1, figsize=(6, 6))\n",
    "pyplot.boxplot(result)\n",
    "pyplot.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Looking at the box and whisker plot, the upper whisker ends at roughly 500 words, with a long tail of much longer outliers above it; the mean (about 234 words) and the median are below 250. According to the plot, we can probably cover the mass of the distribution with a clipped length of 400 to 500 words. Here we set the maximum sequence length of each sample to 500."
   ]
  },
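  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "One way to sanity-check this choice is to compute the fraction of reviews that fit within a candidate clip length without truncation. A small NumPy sketch, shown on a made-up list of lengths; in the notebook you could pass `result` from the cell above instead (`coverage` is a hypothetical helper):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def coverage(lengths, max_len):\n",
    "    # Fraction of reviews whose length is at most max_len (i.e. not truncated).\n",
    "    lengths = np.asarray(lengths)\n",
    "    return float(np.mean(lengths <= max_len))\n",
    "\n",
    "lengths = [120, 230, 180, 510, 95, 300, 1200, 450]  # made-up review lengths\n",
    "for max_len in (400, 500):\n",
    "    print(max_len, coverage(lengths, max_len))  # prints 400 0.625, then 500 0.75\n",
    "```"
   ]
  },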
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "The corresponding vocabulary, sorted by frequency, is also required so that the words can later be embedded with pre-trained vectors. The downloaded vocabulary is a {word: index} dictionary, with each word as a key and its index as the value. It needs to be inverted into {index: word} format.\n",
    "\n",
    "Let's define a function to obtain the vocabulary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing vocabulary\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finished processing vocabulary\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "def get_word_index(dest_dir='/tmp/.bigdl/dataset'):\n",
    "    \"\"\"Retrieve the dictionary mapping words to their indices.\n",
    "\n",
    "    :argument\n",
    "        dest_dir: where to cache the data (default `/tmp/.bigdl/dataset`).\n",
    "\n",
    "    :return\n",
    "        The word index dictionary in {word: index} format.\n",
    "    \"\"\"\n",
    "    file_name = \"imdb_word_index.json\"\n",
    "    path = base.maybe_download(file_name,\n",
    "                               dest_dir,\n",
    "                               source_url='https://s3.amazonaws.com/text-datasets/imdb_word_index.json')\n",
    "    with open(path) as f:\n",
    "        data = json.load(f)\n",
    "    return data\n",
    "\n",
    "print('Processing vocabulary')\n",
    "word_idx = get_word_index()\n",
    "idx_word = {v:k for k,v in word_idx.items()}\n",
    "print('finished processing vocabulary')"
   ]
  },
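  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "With `idx_word` in hand, an encoded review can be turned back into text, which is a handy sanity check. A sketch using a tiny made-up vocabulary; it assumes the sequence has already been shifted by `index_from` as described in the pre-processing section, so the word with rank `r` appears as `r + index_from`, and any index that maps to no word (padding, start, out-of-vocabulary) is shown as `?`. `decode_review` is a hypothetical helper:\n",
    "\n",
    "```python\n",
    "def decode_review(seq, idx_word, index_from=3):\n",
    "    # Undo the index_from shift; helper indices map to no word and print as '?'.\n",
    "    return ' '.join(idx_word.get(i - index_from, '?') for i in seq)\n",
    "\n",
    "# Toy vocabulary in the {index: word} format built above (made-up entries).\n",
    "toy_idx_word = {1: 'the', 2: 'movie', 3: 'was', 4: 'great'}\n",
    "encoded = [2, 4, 5, 6, 7, 3]  # start_char, 'the movie was great', oov_char\n",
    "print(decode_review(encoded, toy_idx_word))  # ? the movie was great ?\n",
    "```"
   ]
  },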
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Text pre-processing\n",
    "\n",
    "Before we train the network, some pre-processing steps need to be applied to the dataset:\n",
    "\n",
    "* We insert a `start_char` at the beginning of each sentence to mark its start; it is set to `2` here. Every word index is also shifted by a constant `index_from` so that real words do not collide with the reserved helper indices (e.g. `padding_value`, `start_char`, `oov_char`).\n",
    "\n",
    "* A `max_words` variable defines the size of the vocabulary: any (shifted) index greater than or equal to `max_words`, i.e. a less frequent word, is replaced by the out-of-vocabulary index `oov_char`, which is `3` here.\n",
    "\n",
    "* Each word index sequence is brought to the same length, `sequence_len`. We pad and truncate on the left: if a sequence is longer than `sequence_len`, its head is dropped so that as much of the end as possible is kept; if it is shorter, its head is padded with `padding_value`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start transformation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finish transformation\n"
     ]
    }
   ],
   "source": [
    "from bigdl.util.common import Sample\n",
    "\n",
    "def replace_oov(x, oov_char, max_words):\n",
    "    \"\"\"\n",
    "    Replace the words out of vocabulary with `oov_char`\n",
    "    :param x: a sequence\n",
    "    :param max_words: the max number of words to include\n",
    "    :param oov_char: words out of vocabulary because of exceeding the `max_words`\n",
    "        limit will be replaced by this character\n",
    "\n",
    "    :return: The replaced sequence\n",
    "    \"\"\"\n",
    "    return [oov_char if w >= max_words else w for w in x]\n",
    "\n",
    "def pad_sequence(x, fill_value, length):\n",
    "    \"\"\"\n",
    "    Pads each sequence to the same length\n",
    "    :param x: a sequence\n",
    "    :param fill_value: pad the sequence with this value\n",
    "    :param length: pad sequence to the length\n",
    "\n",
    "    :return: the padded sequence\n",
    "    \"\"\"\n",
    "    if len(x) >= length:\n",
    "        return x[(len(x) - length):]\n",
    "    else:\n",
    "        return [fill_value] * (length - len(x)) + x\n",
    "\n",
    "def to_sample(features, label):\n",
    "    \"\"\"\n",
    "    Wrap the `features` and `label` to a training sample object\n",
    "    :param features: features of a sample\n",
    "    :param label: label of a sample\n",
    "    \n",
    "    :return: a sample object including features and label\n",
    "    \"\"\"\n",
    "    return Sample.from_ndarray(np.array(features, dtype='float'), np.array(label))\n",
    "\n",
    "padding_value = 1\n",
    "start_char = 2\n",
    "oov_char = 3\n",
    "index_from = 3\n",
    "max_words = 5000\n",
    "sequence_len = 500\n",
    "\n",
    "print('start transformation')\n",
    "\n",
    "from zoo.common.nncontext import *\n",
    "sc = init_nncontext(\"Sentiment Analysis Example\")\n",
    "\n",
    "\n",
    "def to_sample_rdd(x, y):\n",
    "    # Prepend start_char, shift indices, replace OOV words, pad, and wrap as Samples\n",
    "    return sc.parallelize(zip(x, y), 2) \\\n",
    "        .map(lambda record: ([start_char] + [w + index_from for w in record[0]], record[1])) \\\n",
    "        .map(lambda record: (replace_oov(record[0], oov_char, max_words), record[1])) \\\n",
    "        .map(lambda record: (pad_sequence(record[0], padding_value, sequence_len), record[1])) \\\n",
    "        .map(lambda record: to_sample(record[0], record[1]))\n",
    "\n",
    "train_rdd = to_sample_rdd(x_train, y_train)\n",
    "test_rdd = to_sample_rdd(x_test, y_test)\n",
    "\n",
    "print('finish transformation')"
   ]
  },
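  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "To make the three pre-processing steps concrete, the sketch below applies them to one short, made-up review in plain Python. `preprocess` is a hypothetical helper that mirrors the `map` chain above; small values of `max_words` and `sequence_len` are used so that the effect is visible:\n",
    "\n",
    "```python\n",
    "def preprocess(seq, start_char=2, index_from=3, oov_char=3,\n",
    "               max_words=20, sequence_len=8, padding_value=1):\n",
    "    # 1. Prepend start_char and shift every word index by index_from.\n",
    "    x = [start_char] + [w + index_from for w in seq]\n",
    "    # 2. Replace indices that are too rare (>= max_words) with oov_char.\n",
    "    x = [oov_char if w >= max_words else w for w in x]\n",
    "    # 3. Left-truncate or left-pad to sequence_len.\n",
    "    if len(x) >= sequence_len:\n",
    "        return x[len(x) - sequence_len:]\n",
    "    return [padding_value] * (sequence_len - len(x)) + x\n",
    "\n",
    "review = [1, 14, 22, 4999, 6]\n",
    "print(preprocess(review))                  # [1, 1, 2, 4, 17, 3, 3, 9]\n",
    "print(preprocess(review, sequence_len=4))  # [17, 3, 3, 9]\n",
    "```\n",
    "\n",
    "Note that when a review is truncated, the `start_char` at its head is dropped along with the other leading words."
   ]
  },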
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Word Embedding\n",
    "\n",
    "[Word embedding](https://en.wikipedia.org/wiki/Word_embedding) is an important advance in natural language processing. The key idea is to encode words and phrases as distributed representations, i.e. each word is represented as a dense vector. Two widely used word vector training algorithms are [word2vec](https://arxiv.org/abs/1310.4546), published by Google, and [GloVe](https://nlp.stanford.edu/projects/glove/), published by Stanford. In this example, pre-trained GloVe vectors are loaded into a lookup table and fine-tuned during the training process. BigDL provides a method to download and load GloVe in its `news20` module."
   ]
  },
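  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Conceptually, an embedding lookup is just row selection from a weight matrix: multiplying a one-hot vector by the matrix gives the same row. A minimal NumPy sketch of that equivalence, using a random made-up matrix (NumPy indexing is 0-based):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "vocab_size, embedding_dim = 6, 4\n",
    "rng = np.random.default_rng(0)\n",
    "weights = rng.uniform(-0.05, 0.05, size=(vocab_size, embedding_dim))\n",
    "\n",
    "sentence = np.array([3, 0, 5])          # word indices\n",
    "by_lookup = weights[sentence]            # shape (3, embedding_dim)\n",
    "\n",
    "one_hot = np.eye(vocab_size)[sentence]   # shape (3, vocab_size)\n",
    "by_matmul = one_hot @ weights            # same rows, but far more work\n",
    "\n",
    "print(by_lookup.shape, np.allclose(by_lookup, by_matmul))  # (3, 4) True\n",
    "```\n",
    "\n",
    "This is why a lookup table is preferred over a dense layer on one-hot inputs: the result is identical, while the lookup never materializes the one-hot vectors."
   ]
  },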
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading glove\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finish loading glove\n"
     ]
    }
   ],
   "source": [
    "from bigdl.dataset import news20\n",
    "import itertools\n",
    "\n",
    "embedding_dim = 100\n",
    "\n",
    "print('loading glove')\n",
    "glove = news20.get_glove_w2v(source_dir='/tmp/.bigdl/dataset', dim=embedding_dim)\n",
    "print('finish loading glove')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "For each index from 1 to `max_words`, we look up the GloVe vector of the corresponding word and store it in an embedding matrix.\n",
    "\n",
    "Words that cannot be found in GloVe are given a vector sampled from a uniform distribution on [-0.05, 0.05].\n",
    "\n",
    "BigDL usually uses a `LookupTable` layer for word embedding, so the matrix is loaded into the `LookupTable` by setting its weights. Because `LookupTable` takes 1-based indices, row `i - 1` of the matrix holds the vector for index `i`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "processing glove\nfinish processing glove\n"
     ]
    }
   ],
   "source": [
    "print('processing glove')\n",
    "w2v = [glove.get(idx_word.get(i - index_from), np.random.uniform(-0.05, 0.05, embedding_dim))\n",
    "        for i in range(1, max_words + 1)]\n",
    "# Stack the per-index vectors into a (max_words, embedding_dim) matrix\n",
    "w2v = np.array(w2v, dtype='float')\n",
    "print('finish processing glove')"
   ]
  },
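  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "The embedding-matrix construction in the cell above can be written out more explicitly as the following self-contained sketch, shown with a made-up GloVe dictionary and tiny sizes. As in the notebook, row `i - 1` holds the vector for index `i`, and any index without a GloVe entry (including the helper indices) gets a random vector; `build_embedding_matrix` and its `seed` parameter are hypothetical additions for illustration:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def build_embedding_matrix(glove, idx_word, max_words, embedding_dim, index_from=3, seed=0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    rows = []\n",
    "    for i in range(1, max_words + 1):\n",
    "        word = idx_word.get(i - index_from)  # None for helper indices\n",
    "        vec = glove.get(word)\n",
    "        if vec is None:\n",
    "            # Unknown word: sample from U[-0.05, 0.05] as in the notebook.\n",
    "            vec = rng.uniform(-0.05, 0.05, embedding_dim)\n",
    "        rows.append(np.asarray(vec, dtype='float'))\n",
    "    return np.stack(rows)  # shape (max_words, embedding_dim)\n",
    "\n",
    "toy_glove = {'the': [0.1, 0.2], 'movie': [0.3, 0.4], 'great': [0.5, 0.6]}\n",
    "toy_idx_word = {1: 'the', 2: 'movie', 3: 'was', 4: 'great'}\n",
    "m = build_embedding_matrix(toy_glove, toy_idx_word, max_words=8, embedding_dim=2)\n",
    "print(m.shape)  # (8, 2)\n",
    "print(m[3])     # vector for index 4, i.e. rank 1 ('the')\n",
    "```"
   ]
  },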
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Build models\n",
    "\n",
    "Next, let's build some deep learning models for sentiment classification.\n",
    "\n",
    "Several models are provided for comparison and demonstration: **LSTM**, **GRU**, **Bi-LSTM**, **CNN** and **CNN + LSTM**. To choose one, set `model_type` to the corresponding string: `lstm`, `gru`, `bi_lstm`, `cnn` or `cnn_lstm`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "from bigdl.nn.layer import *\n",
    "\n",
    "p = 0.2\n",
    "\n",
    "def build_model(w2v):\n",
    "    model = Sequential()\n",
    "\n",
    "    embedding = LookupTable(max_words, embedding_dim)\n",
    "    embedding.set_weights([w2v])\n",
    "    model.add(embedding)\n",
    "    if model_type.lower() == \"gru\":\n",
    "        model.add(Recurrent()\n",
    "                .add(GRU(embedding_dim, 128, p))) \\\n",
    "            .add(Select(2, -1))\n",
    "    elif model_type.lower() == \"lstm\":\n",
    "        model.add(Recurrent()\n",
    "                  .add(LSTM(embedding_dim, 128, p)))\\\n",
    "            .add(Select(2, -1))\n",
    "    elif model_type.lower() == \"bi_lstm\":\n",
    "        model.add(BiRecurrent(CAddTable())\n",
    "                  .add(LSTM(embedding_dim, 128, p)))\\\n",
    "            .add(Select(2, -1))\n",
    "    elif model_type.lower() == \"cnn\":\n",
    "        model.add(Transpose([(2, 3)]))\\\n",
    "            .add(Dropout(p))\\\n",
    "            .add(Reshape([embedding_dim, 1, sequence_len]))\\\n",
    "            .add(SpatialConvolution(embedding_dim, 128, 5, 1))\\\n",
    "            .add(ReLU())\\\n",
    "            .add(SpatialMaxPooling(sequence_len - 5 + 1, 1, 1, 1))\\\n",
    "            .add(Reshape([128]))\n",
    "    elif model_type.lower() == \"cnn_lstm\":\n",
    "        model.add(Transpose([(2, 3)]))\\\n",
    "            .add(Dropout(p))\\\n",
    "            .add(Reshape([embedding_dim, 1, sequence_len])) \\\n",
    "            .add(SpatialConvolution(embedding_dim, 64, 5, 1)) \\\n",
    "            .add(ReLU()) \\\n",
    "            .add(SpatialMaxPooling(4, 1, 1, 1)) \\\n",
    "            .add(Squeeze(3)) \\\n",
    "            .add(Transpose([(2, 3)])) \\\n",
    "            .add(Recurrent()\n",
    "                 .add(LSTM(64, 128, p))) \\\n",
    "            .add(Select(2, -1))\n",
    "    else:\n",
    "        raise ValueError('Unsupported model_type: %s' % model_type)\n",
    "\n",
    "    model.add(Linear(128, 100))\\\n",
    "        .add(Dropout(0.2))\\\n",
    "        .add(ReLU())\\\n",
    "        .add(Linear(100, 1))\\\n",
    "        .add(Sigmoid())\n",
    "\n",
    "    return model"
   ]
  },
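  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "For the `cnn` branch, the shapes work out as follows: a width-5 convolution without padding over `sequence_len` positions yields `sequence_len - 5 + 1` positions per filter, and a max-pooling window of exactly that width reduces each filter to a single value, hence the final `Reshape([128])`. A NumPy sketch of the same arithmetic for one example, using random made-up weights and a naive loop rather than BigDL's implementation:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv1d_global_maxpool(x, w, b):\n",
    "    # x: (embedding_dim, seq_len), w: (n_filters, embedding_dim, k), b: (n_filters,)\n",
    "    n_filters, _, k = w.shape\n",
    "    out_len = x.shape[1] - k + 1  # valid convolution, stride 1\n",
    "    conv = np.empty((n_filters, out_len))\n",
    "    for t in range(out_len):\n",
    "        window = x[:, t:t + k]  # (embedding_dim, k)\n",
    "        conv[:, t] = np.tensordot(w, window, axes=([1, 2], [0, 1])) + b\n",
    "    conv = np.maximum(conv, 0)  # ReLU\n",
    "    return conv, conv.max(axis=1)  # max over time: one value per filter\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.standard_normal((100, 500))          # embedding_dim=100, sequence_len=500\n",
    "w = rng.standard_normal((128, 100, 5)) * 0.01\n",
    "conv, pooled = conv1d_global_maxpool(x, w, np.zeros(128))\n",
    "print(conv.shape, pooled.shape)  # (128, 496) (128,)\n",
    "```"
   ]
  },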
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Optimization\n",
    "An `Optimizer` needs to be created to train the model.\n",
    "\n",
    "Here we use the `GRU` model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "creating: createSequential\ncreating: createLookupTable\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "creating: createRecurrent\ncreating: createTanh\ncreating: createSigmoid\ncreating: createGRU\ncreating: createSelect\ncreating: createLinear\ncreating: createDropout\ncreating: createReLU\ncreating: createLinear\ncreating: createSigmoid\ncreating: createBCECriterion\ncreating: createMaxEpoch\ncreating: createAdam\ncreating: createDistriOptimizer\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "creating: createEveryEpoch\ncreating: createTop1Accuracy\n"
     ]
    }
   ],
   "source": [
    "from bigdl.optim.optimizer import *\n",
    "from bigdl.nn.criterion import *\n",
    "\n",
    "# max_epoch = 4\n",
    "max_epoch = 1\n",
    "batch_size = 64\n",
    "model_type = 'gru'\n",
    "\n",
    "\n",
    "optimizer = Optimizer(\n",
    "        model=build_model(w2v),\n",
    "        training_rdd=train_rdd,\n",
    "        criterion=BCECriterion(),\n",
    "        end_trigger=MaxEpoch(max_epoch),\n",
    "        batch_size=batch_size,\n",
    "        optim_method=Adam())\n",
    "\n",
    "optimizer.set_validation(\n",
    "        batch_size=batch_size,\n",
    "        val_rdd=test_rdd,\n",
    "        trigger=EveryEpoch(),\n",
    "        val_method=Top1Accuracy())"
   ]
  },
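  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "The model ends with a `Sigmoid`, so each output is a probability p, and `BCECriterion` minimizes the binary cross-entropy -[y log p + (1 - y) log(1 - p)]. A NumPy sketch of that loss averaged over a batch (we assume averaging matches BigDL's default `size_average=True`; the clipping constant `eps` is our own addition for numerical safety):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def binary_cross_entropy(p, y, eps=1e-7):\n",
    "    # p: predicted probabilities from the Sigmoid, y: 0/1 labels.\n",
    "    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))\n",
    "\n",
    "print(binary_cross_entropy([0.9, 0.2, 0.6], [1, 0, 1]))\n",
    "```"
   ]
  },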
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "To visualize the training process in TensorBoard, the training and validation summaries are saved as logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "creating: createTrainSummary\ncreating: createSeveralIteration\ncreating: createValidationSummary\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<bigdl.optim.optimizer.Optimizer at 0x7fdd52119fd0>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import datetime as dt\n",
    "\n",
    "logdir = '/tmp/.bigdl/'\n",
    "app_name = 'adam-' + dt.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "\n",
    "train_summary = TrainSummary(log_dir=logdir, app_name=app_name)\n",
    "train_summary.set_summary_trigger(\"Parameters\", SeveralIteration(50))\n",
    "val_summary = ValidationSummary(log_dir=logdir, app_name=app_name)\n",
    "optimizer.set_train_summary(train_summary)\n",
    "optimizer.set_val_summary(val_summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Now, let's start training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization Done.\nCPU times: user 178 ms, sys: 61.5 ms, total: 239 ms\nWall time: 37min 35s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "train_model = optimizer.optimize()\n",
    "print (\"Optimization Done.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "### Test\n",
    "The validation accuracy is already shown in the training log; here we also compute the accuracy on the validation set by hand.\n",
    "\n",
    "Predict on `test_rdd` (the validation data), and collect the predicted labels and ground-truth labels into arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "predictions = train_model.predict(test_rdd)\n",
    "\n",
    "def map_predict_label(l):\n",
    "    if l > 0.5:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0\n",
    "def map_groundtruth_label(l):\n",
    "    return l.to_ndarray()[0]\n",
    "\n",
    "y_pred = np.array([ map_predict_label(s) for s in predictions.collect()])\n",
    "\n",
    "y_true = np.array([map_groundtruth_label(s.label) for s in test_rdd.collect()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Then let's check the prediction accuracy on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction accuracy on validation set is:  0.89312\n"
     ]
    }
   ],
   "source": [
    "# Fraction of reviews whose predicted label matches the ground truth\n",
    "accuracy = np.mean(y_pred == y_true)\n",
    "print('Prediction accuracy on validation set is: ', accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Show the confusion matrix"
   ]
  },
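  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "A 2x2 confusion matrix can also be computed from `y_pred` and `y_true` with NumPy alone. A minimal sketch (rows are true labels, columns are predicted labels); `confusion_matrix_2x2` is a hypothetical helper, the labels below are made up, and the integer cast handles the float ground-truth labels returned by `to_ndarray()`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def confusion_matrix_2x2(y_true, y_pred):\n",
    "    # cm[i, j] counts examples with true label i and predicted label j.\n",
    "    cm = np.zeros((2, 2), dtype=int)\n",
    "    rows = np.asarray(y_true, dtype=int)\n",
    "    cols = np.asarray(y_pred, dtype=int)\n",
    "    np.add.at(cm, (rows, cols), 1)\n",
    "    return cm\n",
    "\n",
    "y_true_demo = [1, 0, 1, 1, 0, 0]\n",
    "y_pred_demo = [1, 0, 0, 1, 1, 0]\n",
    "cm = confusion_matrix_2x2(y_true_demo, y_pred_demo)\n",
    "print(cm)\n",
    "print('accuracy:', np.trace(cm) / cm.sum())\n",
    "```"
   ]
  },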
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x7fdd3275ab10>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAT0AAAD8CAYAAAAFWHM4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF/5JREFUeJzt3XmUFdW1x/HvZhQQuhlkRkEljviMoBKcUGYUwSlBDRDCgkRN3osxDjGJvigqRoPDC5oQwQAqhqAG9KGIDApJQECQQZ7SQVFmtZkEE+ju/f64RXsLu+Hehu5L3/P7uM6i6tSp6lMLe7NPnRrM3RERCUWVTHdARKQiKeiJSFAU9EQkKAp6IhIUBT0RCYqCnogERUFPRIKioCciQVHQE5GgVCvvH7D3szV65KOSqtX8gkx3QQ5BwZ71Vpb90vmdrd7o+DL9jExSpiciQSn3TE9EKpmiwkz3oFwp6IlIXGFBpntQrhT0RCTGvSjTXShXCnoiElekoCciIVGmJyJB0USGiARFmZ6IhMQ1eysiQdFEhogERcNbEQmKJjJEJCjK9EQkKJrIEJGgaCJDRELirmt6IhISXdMTkaBoeCsiQVGmJyJBKdyb6R6UKwU9EYnT8FZEgqLhrYgERZmeiARFQU9EQuKayBCRoGT5Nb0qme6AiBxhiopSLwdhZmPNbIuZrUiqa2BmM8xsdfRn/ajezOxxM8szs2VmdlbSPoOi9qvNbFBSfXszWx7t87iZ2cH6pKAnInFelHo5uD8BPferuwOY6e5tgZnROkAvoG1UhgFPQiJIAncD5wLnAHfvC5RRm6FJ++3/s75GQU9E4g5jpufubwH5+1X3BcZFy+OAfkn14z1hPpBrZs2AHsAMd893963ADKBntK2eu893dwfGJx2rVLqmJyJx5X9Nr4m7b4yWNwFNouUWwCdJ7dZFdQeqX1dC/QEp6IlIXEHqLxE1s2EkhqL7jHb30anu7+5uZp5G7w6Zgp6IxKWR6UUBLuUgF9lsZs3cfWM0RN0S1a8HWiW1axnVrQc671c/J6pvWUL7A9I1PRGJO4zX9EoxFdg3AzsImJJUPzCaxe0IbI+GwdOB7mZWP5rA6A5Mj7btMLOO0aztwKRjlUqZnojEHcZremY2kUSW1sjM1pGYhR0BTDKzIcBa4NtR82lAbyAP2A0MBnD3fDO7F1gYtbvH3fdNjtxIYoa4FvBqVA5IQU9E4g7jY2jufm0pm7qU0NaBm0o5zlhgbAn1i4DT0+mTgp6IxGX5ExkKeiISl8bsbWWkoCcicV6hd5BUOAU9EYnTq6VEJCgKeiISFE1kiEhQCgsz3YNypaAnInEa3opIUBT0RCQouqYnIiHxIt2nJyIh0fBWRIKi2VsRCUqWZ3p6iSjwy/tHcuGl/en33R8W102fNZe+1/+Aduf3ZsWqD762z8ZNWzi76xU8/dzk4rodO7/g5l8Mp8+1Q+lz3TCWrliV0rHk8Pnj6N+yYd27LF0ys7iufv1cXps2kVUr5/HatInk5uYA0KdPd95ZPINFC19n/j+mcV6ns4v3eeD+O1m6ZCZLl8zkmmsur/DzyKjyf4loRinoAf16d+P3I4fH6k48/jgevf9XtD+z5Fd1/eZ/RnNBxw6xuhGP/p7zzu3AyxP/yIvjRnH8ca1SOpYcPuPHT+LSy66P1d1+203Mmj2PU047n1mz53H7bYlXts2aNY+z2nejw9ndGTrsFv7wh4cB6N2rC988sx3tO3Sn03mX8dObf0DdukdX+LlkjHvqpRI6aNAzs5PN7PboQ7qPR8unVETnKkqHM9uRU69urO6E1sfS5riWJbaf+dbfadGsKSe0Oa64bucXu1j87gqu6tMDgOrVq1Mv+kU50LHk8Jo7bwH5W7fF6vr06cH4CX8BYPyEv3D55YlPo+7atbu4TZ3atfHol/iUU9oyd94CCgsL2b37S5YvX0WPHhdX
0BkcAULO9MzsduB5wIC3o2LARDO740D7Zqvdu79k7DN/4cbvx7OJ9Rs2UT83h1/eN5Krv3cTdz3wKLu//FeGeinJmjRuxKZNiW/PbNq0hSaNGxVv69u3JyuWv8nUKeMYOvQWAJYte48e3TtTq9ZRNGxYn84XdaJVy+YZ6XtGFHnqpRI6WKY3BDjb3Ue4+zNRGUHiK+NDyr97R55RY59hwHeuoHbtWrH6gsJCVn2Qx3euuJTJfxpFrVpHMWbCpAz1Ug7Ek4ZlU6a8xuntLuKqq4fw6/++FYAZb7zFq6/NYu5bU3l2whPMX7CYwiyf0YwpLEy9VEIHC3pFQEn/xDWLtpXIzIaZ2SIzW/TU+ImH0r8jzvKV7zPyiTF0v2oQz0z6K38c/2eemzyVpo0b0eSYRpxx2skAdO98Pu99kJfh3grA5i2f0bRpYwCaNm3Mlk8//1qbufMW0KbNsTRsWB+AB0Y8Toezu9Oz97WYGatXr6nQPmeSFxWlXCqjg92y8hNgppmt5qsvjB8LnAj8qLSdkr+FufezNZUzBy7F+CcfLl4eNeYZatc6iuuuTszuNW18DB+uXUeb41oyf/FSTmh9bKa6KUleefl1Bg64ht88NIqBA67h5ZenA3DCCa355z8/AuCbZ55OzZo1+PzzrVSpUoXc3Bzy87fSrt0ptGt3Cq/PeDODZ1DBKumwNVUHDHru/pqZfYPEcLZFVL0eWOjulTO3LcGtd49g4ZJlbNu2gy79vsuNQwaQU+9oHnjkSfK3befGW+/m5LbHM/qR+w54nDtvvoHbf/0b9hbspVXzZtx7580AvPHm39I+lpTNMxNGcdGF36JRowZ8tGYRv77nYR58aBTPP/d7Bn/vWj7+eB39r0vcmnTlFb357nevZu/eAv715b+47vobgMQk1JzZLwKwc8cXDPref4Y1vM3yZ2/Ny3naOdsyvZDUan5Bprsgh6Bgz3ory3677rk+5d/ZOnc9W6afkUl6IkNE4gqyO6tV0BORuCwf3iroiUhcyBMZIhKeynorSqoU9EQkTpmeiARFQU9EgpLl9yQq6IlIjL6RISJhUdATkaBk+eyt3pwsInGH+X16Znazma00sxVmNtHMjjKzNma2wMzyzOzPZlYjalszWs+LtrdOOs7Po/r3zaxHWU9PQU9E4g5j0DOzFsB/Ah3c/XSgKtAfeBB4xN1PBLby1fs5hwBbo/pHonaY2anRfqcBPYEnzKxqWU5PQU9EYrywKOWSompALTOrBtQGNgKXAPu+qjUO6Bct943WibZ3MTOL6p9393+7+4dAHom3P6VNQU9E4g5jpufu64GHgY9JBLvtwGJgm7sXRM3W8dWr61oQvbsz2r4daJhcX8I+aVHQE5EYL/KUS/Jb0qMyLPlYZlafRJbWhsRb2OuQGJ5mjGZvRSQujVtWkt+SXoquwIfu/imAmb0InAfkmlm1KJtrSeLlxER/tgLWRcPhHODzpPp9kvdJizI9EYkrSqMc3MdARzOrHV2b6wK8B8wGro7aDAKmRMtTo3Wi7bM88abjqUD/aHa3DdCWxNcZ06ZMT0RivODw3afn7gvMbDLwDlAALCGRGf4v8LyZDY/qxkS7jAEmmFkekE9ixhZ3X2lmk0gEzALgprJ+skKvi5dS6XXxlVtZXxe/7TsXp/w7m/vn2XpdvIhUbnr2VkTCkt1PoSnoiUicMj0RCYsyPREJSfFzEllKQU9EYrL8C5AKeiKyHwU9EQmJMj0RCYqCnogExQsr3UMWaVHQE5EYZXoiEhQvUqYnIgFRpiciQXFXpiciAVGmJyJBKdLsrYiERBMZIhIUBT0RCUo5f0Ei4xT0RCRGmZ6IBEW3rIhIUAo1eysiIVGmJyJB0TU9EQmKZm9FJCjK9EQkKIVFVTLdhXKloCciMRreikhQijR7KyIh0S0rIhIUDW8P0bEnXlbeP0LKya53/pTpLkgGaHgrIkHJ9tnb7D47EUmbp1FSYWa5ZjbZzP7PzFaZ2bfM
rIGZzTCz1dGf9aO2ZmaPm1memS0zs7OSjjMoar/azAaV9fwU9EQkpsgt5ZKix4DX3P1k4D+AVcAdwEx3bwvMjNYBegFtozIMeBLAzBoAdwPnAucAd+8LlOlS0BORGHdLuRyMmeUAFwJjEsf2Pe6+DegLjIuajQP6Rct9gfGeMB/INbNmQA9ghrvnu/tWYAbQsyznp6AnIjFFaRQzG2Zmi5LKsP0O1wb4FHjazJaY2VNmVgdo4u4bozabgCbRcgvgk6T910V1pdWnTRMZIhLjpD576+6jgdEHaFINOAv4sbsvMLPH+Goou+8YbmYVdqOMMj0RiSlwS7mkYB2wzt0XROuTSQTBzdGwlejPLdH29UCrpP1bRnWl1adNQU9EYhxLuRz0WO6bgE/M7KSoqgvwHjAV2DcDOwiYEi1PBQZGs7gdge3RMHg60N3M6kcTGN2jurRpeCsiMUWH/5A/Bp41sxrAGmAwiYRrkpkNAdYC347aTgN6A3nA7qgt7p5vZvcCC6N297h7flk6o6AnIjHpXNNL6XjuS4EOJWzqUkJbB24q5ThjgbGH2h8FPRGJKYdM74iioCciMYWHOdM70ijoiUhMlr8tXkFPROKKlOmJSEiy/HV6CnoiEqeJDBEJSpFpeCsiASnMdAfKmYKeiMRo9lZEgqLZWxEJimZvRSQoGt6KSFB0y4qIBKVQmZ6IhESZnogERUFPRIKS+udsKycFPRGJUaYnIkHRY2giEhTdpyciQdHwVkSCoqAnIkHRs7ciEhRd0xORoGj2VkSCUpTlA1wFPRGJ0USGiAQlu/M8BT0R2Y8yPREJSoFld66noCciMdkd8hT0RGQ/2T68rZLpDojIkaUIT7mkysyqmtkSM3slWm9jZgvMLM/M/mxmNaL6mtF6XrS9ddIxfh7Vv29mPcp6fgp6IhLjaZQ0/BewKmn9QeARdz8R2AoMieqHAFuj+keidpjZqUB/4DSgJ/CEmVVN++RQ0BOR/RSlUVJhZi2BS4GnonUDLgEmR03GAf2i5b7ROtH2LlH7vsDz7v5vd/8QyAPOKcv5KeiJSEwhnnJJ0aPAbXwVJxsC29y9IFpfB7SIllsAnwBE27dH7YvrS9gnLQp6IhKTTqZnZsPMbFFSGZZ8LDO7DNji7osr8hwORLO3IhLjaVytc/fRwOgDNDkPuNzMegNHAfWAx4BcM6sWZXMtgfVR+/VAK2CdmVUDcoDPk+r3Sd4nLcr0RCTmcF7Tc/efu3tLd29NYiJilrtfD8wGro6aDQKmRMtTo3Wi7bPc3aP6/tHsbhugLfB2Wc5PmV4JRv5uON16XMRnn+Zzcae+APx+7G85oW0bAHJy6rJ9+066XXAlAD++eSjXDriKwsJCfnX7/cyZ9bfiY1WpUoXX5vyFTRs2M7D/jRV/MgG4a9QzvLloBQ1y6vLSo78AYPvOXdw6ciwbtuTTvHEDHr5lCPWOrs3Tf32DaXMXAlBQWMSH6zfx5tgR5NStwzOvzOaFN/4O7lzZ7TwGXHZx8c94btocnn91LlWrGBe0P52fDuxXYl+yQQW9ZeV24HkzGw4sAcZE9WOACWaWB+STCJS4+0ozmwS8BxQAN7l7md6CpaBXgknPvcTTf3yWx58cUVz3w+/fUrx89/Db2LFjJwDfOOkE+l7Vi84d+9CkWWMm/XUM57XvTVFR4t/BoTcMYPX7/6Ru3aMr9iQCcnnnjvTvdRG/eHx8cd2Yl2ZwbruTGHJld8a8+DpjXnqdmwf0Y3C/rgzu1xWAOQuXM+GV2eTUrcPqjzfwwht/57kHb6V6tarccO8TXNT+dI5tdgxvL/+A2W8vZ/LIO6hRvTqfb9+ZqVOtEOUV8tx9DjAnWl5DCbOv7v4v4JpS9r8PuO9Q+6HhbQnm/30xW7duL3V7n349+OvkaQD06H0JU154lT179vLJ2vV8tOZjvtm+HQDNmjehS/eLeG7CCxXS71B1OO1Eco6uHaubvXAZl198LgCXX3wus95e9rX9
Xp23iF7ntwfgw3WbOKNta2rVrEG1qlXpcNqJvLFgKQCTps9lyBXdqFG9OgANc+qW5+lkXAGecqmMyhz0zGzw4exIZdGxU3s++/RzPlyzFoCmzRqzYf2m4u0bNmymabMmANzzwB0Mv+vh4qxPKk7+tp0cUz8HgEa59cjfFs/Ovvz3Hv62dBXdOp4JwInHNuedVXls2/kFX/57D3PfWcnmz7YCsHbjFhav+ifX3fEQg3/1KCvy1lbsyVQwT+O/yuhQMr1fl7YheRp7956th/Ajjjz9rrqUl16YdtB2XaNrgsvefa8CeiUHYmaw33cf3ly0nDNPOp6cunUAOL5lUwb368YP7hnFDfeO4qTWLalSJfHrUVBYxI4vdvHsAz/jpwP78bPfjiVxbT07He6bk480B7ymZ2ZfHxNEm4Ampe2XPI3dLPfUrPm/o2rVqvTu05Uenb+65LBp4xaat2havN68eRM2bdxMj16X0L3XxXTpfiE1a9akbt06/O4PD/KjH9yeia4Hp0FuXT7dup1j6ufw6dbtNNhvSPravMX0uqB9rO7Krp24smsnAB57dipNGuYC0KRhLl3OPRMzo13b1lQxY+uOL752zGxRWTO4VB0s02sCDAT6lFA+L9+uHXku7Pwt8lZ/yMYNm4vrpr86m75X9aJGjeq0Oq4FbU44jiWLl3P/PY/Q/rRLOOeMbvxwyC3Me2uBAl4F6tyhHVNnLwBg6uwFXHz2GcXbdu76kkXv5cXqgOIJio2f5jNz/rv0vqADAJeccwYLV3wAwEcbNrO3oID69bJ3YiroTA94BTja3Zfuv8HM5pRLj44ATzz1EJ3OP4cGDXNZvHIWD4/4HRMnvEjfq3oVT2Ds88H/5fHyS9N5c8HLFBQUcufPhusaXgW7beTTLFq5mm07v6Dr0F9y43d6M+TKbvzst2N5aeY/aHZMAx6+5fvF7WcteJdO/3EytY+qGTvOTx96iu07d1GtalXuHPpt6tVJTI5cccm3uOuJZ7niJ/dRvVpVhv94QGLInKUKs3joDmDlfW0im4a3oflo3mOZ7oIcgpqndytTZL7uuCtS/p19bu1LlS766z49EYnJ9mt6CnoiEpPtF2cU9EQkRh/7FpGgaHgrIkHJ9tlbBT0RidHwVkSCookMEQmKrumJSFA0vBWRoGTzG2RAQU9E9pPGpx0rJQU9EYnR8FZEgqLhrYgERZmeiARFt6yISFD0GJqIBEXDWxEJioKeiARFs7ciEhRleiISFM3eikhQCj27Xy6loCciMbqmJyJB0TU9EQmKrumJSFCKsnx4WyXTHRCRI4un8d/BmFkrM5ttZu+Z2Uoz+6+ovoGZzTCz1dGf9aN6M7PHzSzPzJaZ2VlJxxoUtV9tZoPKen4KeiISU+hFKZcUFAC3uPupQEfgJjM7FbgDmOnubYGZ0TpAL6BtVIYBT0IiSAJ3A+cC5wB37wuU6VLQE5GYIveUy8G4+0Z3fyda3gmsAloAfYFxUbNxQL9ouS8w3hPmA7lm1gzoAcxw93x33wrMAHqW5fx0TU9EYsprIsPMWgPfBBYATdx9Y7RpE9AkWm4BfJK027qorrT6tCnTE5GYdDI9MxtmZouSyrCSjmlmRwMvAD9x9x3J2zxxY2CFzZ4o0xORmHQyPXcfDYw+UBszq04i4D3r7i9G1ZvNrJm7b4yGr1ui+vVAq6TdW0Z164HO+9XPSbmjSZTpiUhMoRemXA7GzAwYA6xy95FJm6YC+2ZgBwFTkuoHRrO4HYHt0TB4OtDdzOpHExjdo7q0KdMTkZjD/BjaecAAYLmZLY3q7gRGAJPMbAiwFvh2tG0a0BvIA3YDg6M+5ZvZvcDCqN097p5flg4p6IlIzOF8DM3d5wFWyuYuJbR34KZSjjUWGHuofVLQE5EYvXBARIKS7Y+hKeiJSIxeOCAiQdFLREUkKLqmJyJB0TU9EQmKMj0RCYpeFy8iQVGmJyJB0eytiARFExkiEhQNb0UkKHoiQ0SCokxP
RIKS7df0LNujenkzs2HRK7OlEtLfX3j0uvhDV+KHUKTS0N9fYBT0RCQoCnoiEhQFvUOn60GVm/7+AqOJDBEJijI9EQmKgt4hMLOeZva+meWZ2R2Z7o+kzszGmtkWM1uR6b5IxVLQKyMzqwqMAnoBpwLXmtmpme2VpOFPQM9Md0IqnoJe2Z0D5Ln7GnffAzwP9M1wnyRF7v4WkJ/pfkjFU9AruxbAJ0nr66I6ETmCKeiJSFAU9MpuPdAqab1lVCciRzAFvbJbCLQ1szZmVgPoD0zNcJ9E5CAU9MrI3QuAHwHTgVXAJHdfmdleSarMbCLwD+AkM1tnZkMy3SepGHoiQ0SCokxPRIKioCciQVHQE5GgKOiJSFAU9EQkKAp6IhIUBT0RCYqCnogE5f8B5smvhsQwjFAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x288 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sn\n",
    "import pandas as pd\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# Entry [i, j] counts reviews whose true label is i and predicted label is j\n",
    "cm = confusion_matrix(y_true, y_pred)\n",
    "\n",
    "df_cm = pd.DataFrame(cm)\n",
    "plt.figure(figsize=(5, 4))\n",
    "sn.heatmap(df_cm, annot=True, fmt='d')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "source": [
    "Because of space limitations, the results of the other optional models are not shown here. Please try them to see how they perform. To improve the results, experiment with training parameters that may affect accuracy, such as the maximum sequence length, batch size, number of training epochs, preprocessing scheme, and optimization method. Among the models, the CNN trains much faster. Note that LSTM and its variants (e.g. GRU) are harder to train: even an unsuitable batch size may prevent the model from converging. They are also prone to overfitting, so try different dropout rates and/or add regularizers."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
